Split

将一个张量(Tensor)沿着指定的轴(axis)拆分为多个子张量。子张量在指定轴上的大小由 split_sizes 数组决定。

\[\text{input.shape} = [d_0, d_1, \dots, d_{\text{axis}}, \dots, d_{n-1}]\]
\[\text{对于第 } j \text{ 个输出 } (0 \le j < \text{num\_split}): \quad \text{output}_j\text{.shape} = [d_0, d_1, \dots, \text{split\_sizes}[j], \dots, d_{n-1}]\]
输入:
  • input - 输入数据起始地址。

  • axis - 指定拆分的维度轴。

  • input_shape - 输入张量的形状数组地址。

  • input_ndim - 输入张量的维度数。

  • num_split - 拆分出的子张量个数,同时也是 split_sizes 数组的元素个数。

  • split_sizes - 长度为 num_split 的数组,第 j 个元素为第 j 个子张量在拆分轴上的长度。

  • core_mask(int) - 核掩码。仅共享存储版本(*_split_s)带有该参数,私有存储版本(*_split_p)没有该参数。

输出:
  • outputs - 含 num_split 个元素的指针数组地址,其中第 j 个元素指向第 j 个子张量的存储地址。

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持 int8, int16, int32, fp32, fp64, cplx64, cplx128(对应函数前缀 i8, i16, i32, fp, dp, c64, c128)

  • MT7004 支持 fp16, fp32, int16, int32, cplx64(对应函数前缀 hp, fp, i16, i32, c64)

  • split_sizes 的元素之和必须等于输入张量在 axis 维度的长度。

  • 对于复数类型(cplx64 / cplx128),数据通过 float * / double * 指针传入,拆分逻辑与实数一致,但地址偏移以一个复数对(两个连续的 float / double)为单位计算,而不是以单个 float / double 为单位。

共享存储版本:

void i8_split_s(int8_t *input, int8_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
void i16_split_s(int16_t *input, int16_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
void i32_split_s(int32_t *input, int32_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
void hp_split_s(half *input, half *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
void fp_split_s(float *input, float *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
void dp_split_s(double *input, double *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
void c64_split_s(float *input, float *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)
void c128_split_s(double *input, double *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes, int core_mask)

C调用示例:

//FT78NE示例(共享存储)
#include <stdio.h>
#include "78NE/utils.h"

int main(int argc, char* argv[]) {
    float *input = (float *)0xA0000000;
    float *out0 = (float *)0xB0000000;
    float *out1 = (float *)0xB1000000;
    float *outputs[] = { out0, out1 };
    int input_shape[] = { 2, 10, 4 };
    int split_sizes[] = { 6, 4 };
    int axis = 1;
    int input_ndim = 3;
    int num_split = 2;
    int core_mask = 0b1011;

    fp_split_s(input, outputs, axis, input_shape, input_ndim, num_split, split_sizes, core_mask);
    return 0;
}

私有存储版本:

void i8_split_p(int8_t *input, int8_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
void i16_split_p(int16_t *input, int16_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
void i32_split_p(int32_t *input, int32_t *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
void hp_split_p(half *input, half *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
void fp_split_p(float *input, float *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
void dp_split_p(double *input, double *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
void c64_split_p(float *input, float *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)
void c128_split_p(double *input, double *outputs[], int axis, int *input_shape, int input_ndim, int num_split, int *split_sizes)

C调用示例:

//MT7004 示例
#include <stdio.h>

int main(int argc, char* argv[]) {
    float *input = (float *)0x10000000; // 私有存储地址
    float *out0 = (float *)0x10010000;
    float *out1 = (float *)0x10020000;
    float *outputs[] = { out0, out1 };
    int input_shape[] = { 20, 10 };
    int split_sizes[] = { 10, 10 };
    int axis = 0;
    fp_split_p(input, outputs, axis, input_shape, 2, 2, split_sizes);
    return 0;
}